use super::{Allocator, Boxed, Error, NullAlloc};
use core::alloc::{GlobalAlloc, Layout};
use core::cell::Cell;
use core::mem::size_of;
use core::ptr::{self, NonNull};

/// Allocation pool that hands out consecutive regions of a caller-supplied
/// byte buffer by advancing a cursor.
///
/// `Pool<'static>` is `Send`, but the pool is not `Sync`: the cursor is a
/// plain `Cell`, so a pool must only be used from one thread at a time.
#[repr(C)]
pub struct Pool<'a> {
    /// Backing storage that every allocation is carved out of.
    buf: &'a mut [u8],
    /// Offset into `buf` of the first byte that has not been handed out yet.
    pos: Cell<usize>,
}

// SAFETY: a `Pool<'static>` consists of an exclusive `&'static mut [u8]` and a
// `Cell<usize>`; moving it to another thread moves the only handle to both, so
// no state stays shared with the sending thread.
unsafe impl Send for Pool<'static> {}

/// A `Pool` held through a `Boxed` handle that uses the `NullAlloc` allocator;
/// this is the handle type returned by `Pool::new_boxed`.
pub type BoxedPool<'a> = Boxed<'a, Pool<'a>, NullAlloc>;

impl<'a> Pool<'a> {
    /// Creates a pool that allocates from `buf`, with the cursor at offset 0.
    pub fn new(buf: &'a mut [u8]) -> Self {
        Self {
            buf,
            pos: Cell::new(0),
        }
    }

    /// Creates a pool whose own header is stored inside `buf`.
    ///
    /// The `Pool` value is written at the first address in `buf` that is
    /// suitably aligned for it (offset 0 when `buf` already is), and the
    /// cursor starts right behind the header, so allocations never overlap it.
    ///
    /// # Errors
    /// Returns `Error::default()` when `buf` is too small to hold the
    /// alignment padding plus the header.
    pub fn new_boxed(buf: &'a mut [u8]) -> Result<Boxed<'a, Self, NullAlloc>, Error> {
        let layout = Layout::new::<Self>();
        let base = buf.as_mut_ptr();
        // Bytes to skip so the header lands on a properly aligned address.
        // `buf` is only byte-aligned, while `Pool` holds a slice reference and
        // a `usize`; writing it at an unaligned address is undefined behaviour.
        let pad = (base as usize).wrapping_neg() & (layout.align() - 1);
        let used = pad + size_of::<Self>();
        if used > buf.len() {
            return Err(Error::default());
        }

        // SAFETY: `pad + size_of::<Self>() <= buf.len()`, so `header` and the
        // bytes written through it lie inside `buf`; `header` is aligned for
        // `Self` by construction of `pad` and is non-null because it points
        // into a live slice. The write goes through a raw pointer, so no
        // reference to a not-yet-initialised `Pool` is created.
        unsafe {
            let header = base.add(pad).cast::<Self>();
            ptr::write(
                header,
                Pool {
                    buf,
                    pos: Cell::new(used),
                },
            );
            Ok(Boxed::from_with(
                NullAlloc,
                NonNull::new_unchecked(header),
                layout,
            ))
        }
    }

    /// Rewinds a boxed pool: the header is rewritten in the same buffer and
    /// the cursor returns to its initial position, so regions handed out
    /// before the reset will be handed out again.
    ///
    /// # Panics
    /// Panics if the header no longer fits into the pool's buffer. This is not
    /// expected for a pool that was created by [`Pool::new_boxed`], because
    /// the buffer address and length are unchanged.
    pub fn reset_boxed(boxed: Boxed<'a, Self, NullAlloc>) -> Boxed<'a, Self, NullAlloc> {
        let buf: &mut [u8] = Boxed::leak(boxed).buf;
        Pool::new_boxed(buf).expect("buffer of a boxed pool must still hold the pool header")
    }
}

impl Pool<'_> {
    /// # Safety
    /// 使用者保证raw一定是leak返回的内存地址
    pub unsafe fn from_raw(raw: *mut Self) -> Boxed<'static, Self, NullAlloc> {
        Boxed::from_with(
            NullAlloc,
            NonNull::new_unchecked(raw),
            Layout::new::<Self>(),
        )
    }

    pub fn reset(self) -> Self {
        Pool::new(self.buf)
    }

    fn get_pos(&self) -> usize {
        self.pos.get()
    }

    fn set_pos(&self, pos: usize) {
        self.pos.set(pos);
    }

    fn get_buf(&self, pos: usize, size: usize) -> NonNull<[u8]> {
        NonNull::from(&self.buf[pos..(pos + size)])
    }

    fn alloc_buf<F>(&self, align: usize, size: usize, f: F) -> Result<NonNull<[u8]>, Error>
    where
        F: FnOnce(NonNull<[u8]>) -> Result<NonNull<[u8]>, Error>,
    {
        let cur = self.get_pos();
        let pos = (cur + align - 1) & !(align - 1);
        if pos < self.buf.len() && size <= self.buf.len() - pos {
            self.set_pos(pos + size);
            match f(self.get_buf(pos, size)) {
                Ok(buf) => Ok(buf),
                Err(error) => {
                    if self.get_pos() == pos + size {
                        self.set_pos(cur);
                    }
                    Err(error)
                }
            }
        } else {
            Err(Error::default())
        }
    }

    unsafe fn allocate_buf<F>(&self, layout: Layout, f: F) -> Result<NonNull<[u8]>, Error>
    where
        F: FnOnce(NonNull<[u8]>) -> Result<(), Error>,
    {
        self.alloc_buf(layout.align(), layout.size(), |ptr| f(ptr).map(|_| ptr))
    }
}

/// Lets a shared reference to a pool act as an allocator.
unsafe impl Allocator for &Pool<'_> {
    /// Reserves a region for `layout` from the referenced pool and hands it to
    /// `f`; forwards to `Pool::allocate_buf`.
    unsafe fn alloc_buf<F>(&self, layout: Layout, f: F) -> Result<NonNull<[u8]>, Error>
    where
        F: FnOnce(NonNull<[u8]>) -> Result<(), Error>,
    {
        (**self).allocate_buf(layout, f)
    }

    /// Does nothing: the pool never releases individual regions.
    unsafe fn deallocate(&self, _ptr: NonNull<[u8]>, _layout: Layout) {}
}

/// `GlobalAlloc` adapter for a shared reference to a pool.
unsafe impl GlobalAlloc for &Pool<'_> {
    /// Returns the start of a region reserved through `Allocator::alloc_buf`,
    /// or a null pointer when that reservation fails.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        Allocator::alloc_buf(self, layout, |_| Ok(()))
            .map_or(ptr::null_mut(), |block| block.cast::<u8>().as_ptr())
    }

    /// Does nothing: the pool never releases individual regions.
    unsafe fn dealloc(&self, _ptr: *mut u8, _layout: Layout) {}
}

/// Lets an exclusive reference to a pool act as an allocator.
unsafe impl Allocator for &mut Pool<'_> {
    /// Reserves a region for `layout` from the referenced pool and hands it to
    /// `f`; forwards to `Pool::allocate_buf`.
    unsafe fn alloc_buf<F>(&self, layout: Layout, f: F) -> Result<NonNull<[u8]>, Error>
    where
        F: FnOnce(NonNull<[u8]>) -> Result<(), Error>,
    {
        (**self).allocate_buf(layout, f)
    }

    /// Does nothing: the pool never releases individual regions.
    unsafe fn deallocate(&self, _ptr: NonNull<[u8]>, _layout: Layout) {}
}

/// `GlobalAlloc` adapter for an exclusive reference to a pool.
unsafe impl GlobalAlloc for &mut Pool<'_> {
    /// Returns the start of a region reserved through `Allocator::alloc_buf`,
    /// or a null pointer when that reservation fails.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        Allocator::alloc_buf(self, layout, |_| Ok(()))
            .map_or(ptr::null_mut(), |block| block.cast::<u8>().as_ptr())
    }

    /// Does nothing: the pool never releases individual regions.
    unsafe fn dealloc(&self, _ptr: *mut u8, _layout: Layout) {}
}

#[cfg(test)]
mod test {
    use crate::{Boxed, BufPool, Error};
    use core::mem::{align_of, size_of, MaybeUninit};
    use core::ptr;

    /// Backing storage for the pools under test.
    ///
    /// A bare `[u8; N]` only guarantees 1-byte alignment, so carving `i32`s,
    /// pointers or the pool header out of it relies on where the compiler
    /// happens to place the array. The explicit alignment makes the address
    /// arithmetic asserted below deterministic; 16 bytes is assumed to cover
    /// every type these tests allocate.
    #[repr(C, align(16))]
    struct Aligned<const N: usize>([u8; N]);

    struct Foo {
        val: i32,
    }

    /// Values are packed back to back, and the pool reports an error once the
    /// buffer is exhausted.
    #[test]
    fn test_t() {
        let mut buf = Aligned([0_u8; 10 * size_of::<Foo>() + 1]);
        let addr = buf.0.as_ptr() as usize;
        let pool = BufPool::new(&mut buf.0);
        for n in 0..10 {
            let foo = Boxed::new_in(&pool, Foo { val: n as i32 }).unwrap();
            assert_eq!(foo.val, n);
            let pos = &foo.val as *const _ as *const u8 as usize;
            assert_eq!(addr + n as usize * size_of::<Foo>(), pos);
        }
        // Only one spare byte is left, so an eleventh `Foo` cannot fit.
        let foo = Boxed::new_in(&pool, Foo { val: -1 });
        assert!(foo.is_err());
    }

    /// A slice is initialised element by element, in index order.
    #[test]
    fn test_t_array() {
        let mut buf = Aligned([0_u8; 100]);
        let pool = BufPool::new(&mut buf.0);
        let foo = Boxed::new_slice_then_in(&pool, 10, |n, obj: &mut MaybeUninit<Foo>| {
            obj.write(Foo { val: n as i32 });
            Ok(())
        })
        .unwrap();
        foo.iter().fold(0, |n, obj| {
            assert_eq!(n, obj.val);
            n + 1
        });
        assert_eq!(foo[1].val, 1);
    }

    /// A failing initialiser gives its memory back to the pool.
    #[test]
    fn test_t_error() {
        let mut buf = Aligned([0_u8; 4]);
        let pool = BufPool::new(&mut buf.0);
        let foo = Boxed::new_then_in(&pool, |_: &mut MaybeUninit<Foo>| {
            Err(super::Error::default())
        });
        assert!(foo.is_err());
        // The four bytes were released again, so this allocation succeeds.
        let foo = Boxed::new_in(&pool, Foo { val: 100 }).unwrap();
        assert_eq!(foo.val, 100);
        // Now the buffer really is full.
        let foo = Boxed::new_then_in(&pool, |_: &mut MaybeUninit<Foo>| Ok(()));
        assert!(foo.is_err());
    }

    /// When the tenth element fails to initialise, the nine elements before it
    /// are dropped.
    #[test]
    fn test_t_array_error() {
        use core::sync::atomic::{AtomicI32, Ordering};

        struct Foo {}
        // An atomic instead of `static mut`, so counting needs no `unsafe`.
        static COUNT: AtomicI32 = AtomicI32::new(0);
        impl Drop for Foo {
            fn drop(&mut self) {
                COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }
        let mut buf = [0_u8; 100];
        let pool = BufPool::new(&mut buf);
        COUNT.store(0, Ordering::Relaxed);
        let _foo = Boxed::new_slice_then_in(&pool, 10, |n, _: &mut MaybeUninit<Foo>| {
            if n < 9 {
                Ok(())
            } else {
                Err(Error::default())
            }
        });
        assert_eq!(COUNT.load(Ordering::Relaxed), 9);
    }

    /// A boxed pool stores its header at the start of the buffer and allocates
    /// right behind it.
    #[test]
    fn test_new_boxed() {
        let mut buf = Aligned([0_u8; 100]);
        let addr = buf.0.as_ptr() as usize;
        let pool = BufPool::new_boxed(&mut buf.0);
        assert!(pool.is_ok());
        let pool = pool.unwrap();
        let foo = Boxed::new_in(&*pool, Foo { val: 0 }).unwrap();
        assert_eq!(foo.val, 0);
        assert_eq!(
            size_of::<BufPool>(),
            foo.as_ref() as *const _ as usize - pool.as_ref() as *const _ as usize
        );
        assert_eq!(pool.as_ref() as *const _ as usize, addr);
    }

    /// Resetting a boxed pool makes the same addresses available again.
    #[test]
    fn test_reset_boxed() {
        let mut buf = Aligned([0_u8; 4096]);
        let mut pool = BufPool::new_boxed(&mut buf.0).unwrap();
        let addr1;
        {
            let val_u32 = Boxed::uninit_slice_in::<u32>(&*pool, 1000).unwrap();
            addr1 = val_u32.as_ptr() as *const u8 as usize;
        }
        pool = BufPool::reset_boxed(pool);
        {
            let val_u32 = Boxed::uninit_slice_in::<u32>(&*pool, 1000).unwrap();
            assert_eq!(addr1, val_u32.as_ptr() as *const u8 as usize);
            // 4000 of the 4096 bytes are taken; a second slice cannot fit.
            let val_u32 = Boxed::uninit_slice_in::<u32>(&*pool, 1000);
            assert!(val_u32.is_err());
        }
        pool = BufPool::reset_boxed(pool);
        let val_u32 = Boxed::uninit_slice_in::<u32>(&*pool, 1000).unwrap();
        assert_eq!(addr1, val_u32.as_ptr() as *const u8 as usize);
    }

    /// An initialiser may allocate from the same pool, and a failure inside it
    /// releases both the inner and the outer region.
    #[test]
    fn test_alloc_in_ctor() {
        struct Foo<'a> {
            val1: i32,
            val2: crate::Boxed<'a, i32, &'a BufPool<'a>>,
        }
        let mut buf = Aligned([0_u8; 4096]);
        let pool = BufPool::new(&mut buf.0);

        let foo = Boxed::new_then_in(&pool, |foo: &mut MaybeUninit<Foo>| {
            let foo = foo.as_mut_ptr();
            unsafe {
                ptr::addr_of_mut!((*foo).val1).write(99);
                ptr::addr_of_mut!((*foo).val2).write(Boxed::new_in(&pool, 100_i32)?);
            }
            Ok(())
        });
        assert!(foo.is_ok());
        let foo = foo.unwrap();

        assert_eq!(foo.val1, 99);
        assert_eq!(*foo.val2, 100);

        // The nested allocation fails, so nothing may stay reserved.
        let bar = Boxed::new_then_in(&pool, |foo: &mut MaybeUninit<Foo>| {
            let foo = foo.as_mut_ptr();
            unsafe {
                ptr::addr_of_mut!((*foo).val2)
                    .write(Boxed::new_then_in(&pool, |_: &mut MaybeUninit<i32>| {
                        Err(Error::default())
                    })?);
            }
            Ok(())
        });
        assert!(bar.is_err());

        let bar = Boxed::new_then_in(&pool, |bar: &mut MaybeUninit<Foo>| {
            unsafe {
                ptr::addr_of_mut!((*bar.as_mut_ptr()).val2).write(Boxed::new_in(&pool, 100)?);
            }
            Ok(())
        })
        .unwrap();
        // `bar` sits behind `foo` and the `i32` that `foo.val2` points to,
        // rounded up to the alignment of `Foo`.
        assert_eq!(
            &foo.val1 as *const _ as *const u8 as usize + size_of::<Foo>() + align_of::<Foo>(),
            &bar.val1 as *const _ as *const u8 as usize
        );
    }
}
